# CauF-VAE: Causal Disentangled Representation Learning with VAE and Causal Flows
This repository contains code to run and reproduce the experiments presented in [*CauF-VAE: Causal Disentangled Representation Learning with VAE and Causal Flows*](https://openreview.net/attachment?id=J9qCDJa3CA&name=pdf).

## Dependencies
This project was tested with the following versions:

- python 3.7.13
- numpy 1.21.5
- pytorch 1.12.1
- torchvision 0.13.1
- pyyaml 6.0
- scikit-learn 1.0.2
- scipy 1.7.3
- matplotlib 3.5.2
- seaborn 0.12.2
- tqdm 4.42.1
- networkx 2.6.3
- pandas 1.3.5
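To compare a local environment against the versions listed above, a small script along these lines may help. It is a sketch and not part of the repository. It assumes the PyPI distribution names (PyTorch installs as `torch`, PyYAML as `PyYAML`) and assumes PyTorch 1.12.1, the release that pairs with torchvision 0.13.1. On Python 3.7 it needs the `importlib-metadata` backport, because `importlib.metadata` only joined the standard library in Python 3.8.

```python
"""Compare installed package versions with the versions this README lists as tested."""
try:
    from importlib.metadata import PackageNotFoundError, version  # Python 3.8+
except ImportError:  # Python 3.7 needs the `importlib-metadata` backport from PyPI
    from importlib_metadata import PackageNotFoundError, version

# PyPI distribution names. Assumption: torch 1.12.1, the release paired with torchvision 0.13.1.
TESTED_VERSIONS = {
    "numpy": "1.21.5",
    "torch": "1.12.1",
    "torchvision": "0.13.1",
    "PyYAML": "6.0",
    "scikit-learn": "1.0.2",
    "scipy": "1.7.3",
    "matplotlib": "3.5.2",
    "seaborn": "0.12.2",
    "tqdm": "4.42.1",
    "networkx": "2.6.3",
    "pandas": "1.3.5",
}


def find_mismatches(get_version=version):
    """Return {package: (tested, installed)} for every package whose version differs.

    `installed` is None when the package is missing. Local build tags such as the
    "+cu113" in "1.12.1+cu113" are stripped before comparing.
    """
    mismatches = {}
    for package, tested in TESTED_VERSIONS.items():
        try:
            installed = get_version(package).split("+")[0]
        except PackageNotFoundError:
            installed = None
        if installed != tested:
            mismatches[package] = (tested, installed)
    return mismatches


if __name__ == "__main__":
    for package, (tested, installed) in find_mismatches().items():
        print(f"{package}: tested with {tested}, found {installed or 'not installed'}")
```

Newer versions will often work, so a mismatch is a hint for debugging rather than an error.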

## Model
![CauF-VAE model structure](./results/model_structure.png)

## Datasets
- [CelebA](http://mmlab.ie.cuhk.edu.hk/projects/CelebA.html): download the CelebA dataset and place it in the "./src/CauF-VAE_CelebA/causal_data" directory.
- Pendulum: the dataset is generated using the source code from [this repository](https://github.com/huawei-noah/trustworthyAI/blob/master/research/CausalVAE/causal_data/pendulum.py). A copy is provided in the "./src/CauF-VAE_pendulum/causal_data" directory; alternatively, regenerate the pendulum data with the following command:
```
python ./src/CauF-VAE_pendulum/causal_data/pendulum.py
```
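Before training, it may help to confirm that the data is where the scripts expect it. The sketch below (hypothetical helper names, not part of the repository) counts image files under the two data directories named above. It assumes it is run from the repository root and that the datasets are stored as .png, .jpg, or .jpeg files somewhere under each directory.

```python
from pathlib import Path
from typing import Dict

# Data directories named in this README, relative to the repository root.
DATA_DIRS = {
    "CelebA": Path("src/CauF-VAE_CelebA/causal_data"),
    "Pendulum": Path("src/CauF-VAE_pendulum/causal_data"),
}
# Assumption: each dataset consists of image files with one of these suffixes.
IMAGE_SUFFIXES = {".png", ".jpg", ".jpeg"}


def count_images(directory: Path) -> int:
    """Number of image files found recursively under `directory` (0 if it is missing)."""
    if not directory.is_dir():
        return 0
    return sum(1 for p in directory.rglob("*")
               if p.is_file() and p.suffix.lower() in IMAGE_SUFFIXES)


def check_data(repo_root: Path = Path(".")) -> Dict[str, int]:
    """Image count per dataset; 0 suggests the data is missing or still compressed."""
    return {name: count_images(repo_root / rel) for name, rel in DATA_DIRS.items()}


if __name__ == "__main__":
    for name, n_images in check_data().items():
        print(f"{name}: {n_images} image files")
```

Counting images rather than checking that the directory is non-empty matters for Pendulum, whose data directory also holds the generator script itself.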

## Run

- Run CauF-VAE on CelebA:
```
python ./src/CauF-VAE_CelebA/run_celeba_smile.py 
```
Alternatively, run the notebook ./src/CauF-VAE_CelebA/run_celeba_smile.ipynb in Jupyter.

The CelebA (Attractive) experiment is run in the same way.

- Run CauF-VAE on Pendulum:
```
python ./src/CauF-VAE_pendulum/run_pendulum.py
```
Alternatively, run the notebook ./src/CauF-VAE_pendulum/run_pendulum.ipynb in Jupyter.


### Help
Important arguments:
```
Dataset:
  --dataset                   name of the dataset
  --data_dir                  directory of the dataset

General settings:
  --dim                       latent dimension
  --dim1                      number of latent dimensions before the factors
  --dim2                      number of factors of interest
  --dim3                      number of latent dimensions after the factors
  --labels {smile,age,pend}   name of the underlying structure

Supervised regularizer:
  --sup_coef                  coefficient of the supervised regularizer
  --sup_prop                  proportion of supervised labels
  --sup_type {ce,l2}          type of the supervised loss
```
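The flags above leave some relationships implicit. The following pure-Python sketch shows one plausible reading under assumptions that the code does not confirm: the latent vector is laid out as [dim1 | dim2 | dim3] with dim = dim1 + dim2 + dim3, the supervised regularizer acts only on the dim2 factor block, `ce` means binary cross-entropy on logits (as for binary CelebA attributes), and `sup_prop` marks the first fraction of each batch as labelled. All function names are hypothetical, and the repository's own PyTorch implementation may differ.

```python
import math
from typing import Sequence, Tuple

# Hypothetical helpers illustrating --dim1/--dim2/--dim3 and the --sup_* flags;
# they are not the repository's implementation.


def split_latent(z: Sequence[float], dim1: int, dim2: int, dim3: int):
    """Split one latent vector into (before, factors, after) blocks.

    Assumes the layout [dim1 | dim2 | dim3], so len(z) must equal dim1 + dim2 + dim3.
    """
    if len(z) != dim1 + dim2 + dim3:
        raise ValueError(f"expected {dim1 + dim2 + dim3} dimensions, got {len(z)}")
    return z[:dim1], z[dim1:dim1 + dim2], z[dim1 + dim2:]


def per_sample_loss(factors: Sequence[float], labels: Sequence[float], sup_type: str) -> float:
    """Supervised loss between the factor block and one sample's labels."""
    if sup_type == "l2":
        # Sum of squared errors over the factors.
        return sum((f - y) ** 2 for f, y in zip(factors, labels))
    if sup_type == "ce":
        # Binary cross-entropy with each factor treated as a logit, in the numerically
        # stable form max(x, 0) - x * y + log(1 + exp(-|x|)).
        return sum(max(f, 0.0) - f * y + math.log1p(math.exp(-abs(f)))
                   for f, y in zip(factors, labels))
    raise ValueError(f"unknown sup_type: {sup_type!r}")


def supervised_regularizer(latents, labels, dims: Tuple[int, int, int],
                           sup_coef: float, sup_prop: float, sup_type: str) -> float:
    """sup_coef times the mean supervised loss over the labelled part of a batch.

    Assumes the first int(sup_prop * batch_size) samples are the labelled ones.
    """
    n_sup = int(sup_prop * len(latents))
    if n_sup == 0:
        return 0.0
    total = 0.0
    for z, y in zip(latents[:n_sup], labels[:n_sup]):
        _, factors, _ = split_latent(z, *dims)
        total += per_sample_loss(factors, y, sup_type)
    return sup_coef * total / n_sup
```

For example, with `dims=(1, 2, 1)` a 4-dimensional code `[a, f1, f2, b]` is supervised only through `f1` and `f2`; the remaining coordinates are left to the unsupervised VAE objective.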

### Output
Running the scripts generates the following folders:

- **checkpoints**: model checkpoints saved during training
- **figs**: real and reconstructed images saved during training
- **results**: latent traversal images and intervened images
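Traversal figures are typically built by sweeping one latent coordinate over a range while holding the others fixed, then decoding each resulting code into an image. The sketch below builds such a grid of latent codes in plain Python; the helper names and the [-3, 3] range are assumptions, the decoding step is left out because it depends on the trained model, and the repository's own traversal routine may differ.

```python
from typing import List, Sequence

# Hypothetical helpers for building a latent traversal grid; the [-3, 3] range used
# below is a common choice, not one taken from this repository.


def linspace(start: float, stop: float, num: int) -> List[float]:
    """`num` evenly spaced values from start to stop, inclusive."""
    if num == 1:
        return [float(start)]
    step = (stop - start) / (num - 1)
    return [start + i * step for i in range(num)]


def traverse(z: Sequence[float], dim: int, values: Sequence[float]) -> List[List[float]]:
    """Copies of z in which only coordinate `dim` is replaced, once per value."""
    codes = []
    for v in values:
        code = list(z)
        code[dim] = v
        codes.append(code)
    return codes


def traversal_grid(z, dims, values):
    """One row of codes per traversed dimension; each code would then be decoded."""
    return [traverse(z, d, values) for d in dims]


if __name__ == "__main__":
    z = [0.0] * 4  # e.g. the posterior mean of one encoded image
    grid = traversal_grid(z, dims=range(4), values=linspace(-3.0, 3.0, 7))
    print(len(grid), len(grid[0]))  # prints "4 7": 4 rows of 7 codes
```

In a disentangled model, each row should change a single visual attribute.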


## Results

CelebA:
![CauF-VAE latent traversals on CelebA (Smile)](./results/CauF-VAE_smile_traverse.png)
![CauF-VAE interventions on CelebA (Smile)](./results/CauF-VAE_smile_intervene.png)

Pendulum:
![CauF-VAE latent traversals on Pendulum](./results/CauF-VAE_pendulum_traverse.png)
![CauF-VAE interventions on Pendulum](./results/CauF-VAE_pendulum_intervene.png)
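The intervened images show the effect of intervening on learned causal factors. As a generic illustration of what an intervention means (a toy structural causal model, not the causal flow used by CauF-VAE), do(z_i = c) clamps one variable and recomputes only its descendants, leaving its ancestors untouched. The graph, mechanisms, and noise values below are hypothetical.

```python
from typing import Callable, Dict, Mapping, Optional, Sequence

# A mechanism maps (values of already-computed nodes, this node's noise) to the node's value.
Mechanism = Callable[[Dict[str, float], float], float]


def run_scm(order: Sequence[str], mechanisms: Mapping[str, Mechanism],
            noise: Mapping[str, float],
            do: Optional[Mapping[str, float]] = None) -> Dict[str, float]:
    """Evaluate a structural causal model in topological order.

    Nodes listed in `do` are clamped to the given value and their mechanism is
    skipped, so the intervention reaches their descendants but not their ancestors.
    """
    do = do or {}
    values: Dict[str, float] = {}
    for node in order:
        if node in do:
            values[node] = do[node]
        else:
            values[node] = mechanisms[node](values, noise[node])
    return values


# Hypothetical toy graph z1 -> z2 -> z3, with an extra edge z1 -> z3.
ORDER = ["z1", "z2", "z3"]
MECHANISMS = {
    "z1": lambda v, e: e,
    "z2": lambda v, e: 2.0 * v["z1"] + e,
    "z3": lambda v, e: v["z2"] - v["z1"] + e,
}
NOISE = {"z1": 1.0, "z2": 0.0, "z3": 0.5}

if __name__ == "__main__":
    print(run_scm(ORDER, MECHANISMS, NOISE))                   # {'z1': 1.0, 'z2': 2.0, 'z3': 1.5}
    print(run_scm(ORDER, MECHANISMS, NOISE, do={"z2": 10.0}))  # {'z1': 1.0, 'z2': 10.0, 'z3': 9.5}
```

Reusing the same noise in both runs isolates the effect of the intervention, loosely like comparing an image before and after intervening on one factor.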
